import sys
sys.path.append("..")
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from sklearn.metrics.pairwise import rbf_kernel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from torchvision import transforms
import lpips

import io
import base64
import time
import re

from keyword_generator import keyword_generator
from rpg_save_img import RegionalGenerator
import json
import os

# Module-level model setup: everything below runs at import time.
# NOTE(review): the device string is hard-coded to the second CUDA device, while
# the dtype choices below fall back to float32 when CUDA is unavailable --
# presumably the `.to(device)` calls would still fail on such a host; confirm.
device = "cuda:1"
print(f"Using device: {device}")
# Hugging Face model id used for the VAE, tokenizer, text encoder and the
# regional generation pipeline.
model_id = "stabilityai/stable-diffusion-2-1"

# Load the VAE from Stable Diffusion 2.1
vae = AutoencoderKL.from_pretrained(
    model_id,
    subfolder="vae",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
vae = vae.to(device)

# CLIP tokenizer / text encoder pair used by get_text_embeddings().
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    model_id,
    subfolder="text_encoder",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
    # torch_dtype=torch.float32
)

text_encoder = text_encoder.to(device)
# Regional text-to-image pipeline (project-local class); always built in float16.
pipe = RegionalGenerator(model_id, dtype=torch.float16)
# Negative prompt passed to every `pipe(...)` call in cal_lpip_pix().
negative_prompt = "worst quality, low quality, medium quality, deleted, lowres, comic, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry"

# Perceptual-similarity network; it is not moved to `device` here, so it stays
# wherever lpips.LPIPS places it by default (presumably CPU -- TODO confirm).
lpips_model = lpips.LPIPS(net='alex')  # Options: 'alex', 'vgg', 'squeeze'


def centering(K):
    """Double-center a square kernel matrix.

    Builds the centering matrix ``H = I - (1/n) * 1 1^T`` for the ``n x n``
    input and returns ``H @ K @ H``.

    Args:
        K: Square array whose first dimension gives ``n``.

    Returns:
        The centered matrix, same shape as ``K``.
    """
    size = K.shape[0]
    uniform = np.ones((size, size)) / size
    centering_matrix = np.eye(size) - uniform
    left_centered = centering_matrix @ K
    return left_centered @ centering_matrix


def rbf_CKA(X, Y, sigma=None):
    """Compute CKA between two sample sets using RBF kernels.

    Args:
        X: First feature matrix, one row per sample.
        Y: Second feature matrix with the same number of rows as ``X``.
        sigma: Forwarded unchanged as the ``gamma`` argument of
            ``rbf_kernel`` for both inputs (it is *not* converted from a
            bandwidth); ``None`` lets ``rbf_kernel`` pick its default.

    Returns:
        HSIC(X, Y) divided by ``sqrt(HSIC(X, X) * HSIC(Y, Y))``, where each
        HSIC term is the element-wise product sum of centered kernel matrices.
    """
    # Centered Gram matrices for each representation
    gram_x = centering(rbf_kernel(X, gamma=sigma))
    gram_y = centering(rbf_kernel(Y, gamma=sigma))

    # Cross term and the two self terms used for normalization
    cross_term = np.sum(gram_x * gram_y)
    self_x = np.sum(gram_x * gram_x)
    self_y = np.sum(gram_y * gram_y)

    return cross_term / np.sqrt(self_x * self_y)



def get_text_embeddings(text, device=None):
    """Encode ``text`` with the module-level CLIP tokenizer and text encoder.

    Args:
        text: Prompt string (or list of strings) accepted by the tokenizer.
            It is padded/truncated to ``tokenizer.model_max_length`` tokens.
        device: Device the token tensors are moved to before encoding. When
            ``None``, the device the text encoder currently lives on is used.

    Returns:
        The first element of the text encoder output, computed without
        gradient tracking.
    """
    # The parameter shadows the module-level `device`, so the previous
    # `device = device` fallback was a no-op and left the value as None.
    # Using the encoder's own device guarantees inputs and weights match.
    if device is None:
        device = text_encoder.device

    text_input = tokenizer(
        text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    )

    text_input = {k: v.to(device) for k, v in text_input.items()}
    with torch.no_grad():
        text_embeddings = text_encoder(**text_input)[0]

    return text_embeddings


def convert_pil_to_base64(pil_image):
    """Serialize a PIL image to a base64-encoded JPEG string.

    Args:
        pil_image: Image object exposing ``mode``, ``convert`` and ``save``.

    Returns:
        The JPEG bytes encoded as a UTF-8 base64 string (no data-URL prefix).
    """
    # JPEG cannot store alpha or palette data, and saving such an image
    # raises an error; normalize everything except RGB / greyscale to RGB.
    if pil_image.mode not in ("RGB", "L"):
        pil_image = pil_image.convert("RGB")

    buffer = io.BytesIO()
    pil_image.save(buffer, format='JPEG')  # Saves to memory buffer, not disk
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def gpt_response(img, caption, keyword, max_attempts=10, retry_delay=5):
    """Ask GPT-4o-mini to describe the two main objects of an image.

    The image is sent inline as a base64 JPEG together with a prompt that
    asks for two short sentences per object, phrased around ``keyword`` and
    formatted like the ``### Object`` example embedded in the prompt.

    Args:
        img: PIL image to describe.
        caption: Caption text inserted into the prompt.
        keyword: Relation keyword the description should focus on.
        max_attempts: Maximum number of requests before giving up.
        retry_delay: Seconds to sleep between failed attempts.

    Returns:
        The non-empty text content of the first successful response.

    Raises:
        Exception: If every attempt fails or returns empty content. The last
            underlying error, if any, is attached as ``__cause__``.
    """
    img = convert_pil_to_base64(img)

    prompt = f'''
    Given the image and caption, first detect the two most important objects in the image. Then, describe each of the object using the keyword: {keyword} as follow with 2 sentences. For each sentence, use 5-10 words.

    ### The Man
    1. The man is in front of the van. He is beside the sidewalk edge near the street.
    

    ### The Suitcase
    1. The suitcase is behind the man. It is on the curb of the sidewalk.
    
    Now, given the image I uploaded and the caption "{caption}", describe the two most important objects using the keyword {keyword} with exactly the example format:
    '''

    # Prefer the key from the environment so no credential has to live in the
    # source; the previous placeholder is kept as a fallback for compatibility.
    model = ChatOpenAI(model="gpt-4o-mini",
                       openai_api_key=os.environ.get("OPENAI_API_KEY", "xxx"),
                       temperature=0,
                       max_tokens=None,
                       timeout=None,
                       max_retries=2)

    message = HumanMessage(
        content=[
            {"type": "text", "text": prompt},
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{img}"},
            },
        ],
    )

    last_error = None

    for attempts in range(1, max_attempts + 1):
        try:
            response = model.invoke([message])
            if response and response.content and len(response.content.strip()) > 0:
                return response.content
            print(f"Empty response received on attempt {attempts}. Retrying...")
        except Exception as e:  # retry boundary: any API/network failure
            last_error = e
            print(f"Error on attempt {attempts}: {str(e)}")

        # Only sleep if we're going to retry
        if attempts < max_attempts:
            print(f"Waiting {retry_delay} seconds before retry...")
            time.sleep(retry_delay)

    # All attempts exhausted: surface the failure and keep the root cause.
    raise Exception(f"Failed to get response after {max_attempts} attempts") from last_error



def parse_response(input_text):
    """Flatten a model reply into one line of plain sentences.

    The text is cut into period-terminated chunks. A leading list number such
    as ``1.`` is removed from each chunk, and chunks that end up empty or that
    start with a ``###`` header are discarded. Text after the final period is
    ignored.

    Args:
        input_text: Raw reply text.

    Returns:
        The surviving sentences joined by single spaces.
    """
    kept = []
    for chunk in re.findall(r'[^.]*\.', input_text):
        # The pattern is anchored at the start, so substituting
        # unconditionally only ever removes a leading "<digits>." prefix.
        text = re.sub(r'^[0-9]+\.', '', chunk.strip()).strip()
        if not text or text.startswith('###'):  # Skip blanks and headers
            continue
        kept.append(text)

    return ' '.join(kept)

def parse_response_with_obj(text):
    """Split a ``### Object`` formatted reply into object names and contents.

    Each section starts with a ``### <name>`` header line and runs until the
    next header or the end of the text. A leading list number (``1. ``) is
    removed from the section content.

    Args:
        text: Raw reply text.

    Returns:
        A tuple ``(objects, contents)`` of two equally long lists: the header
        names and the corresponding cleaned section bodies.
    """
    # A section ends at the next header regardless of how the separating
    # whitespace looks: one or more newlines, CRLF, or indented/whitespace-only
    # blank lines. Requiring exactly "\n\n###" merged sections whenever the
    # reply deviated from that exact spacing.
    pattern = r"### (.*?)\n(.*?)(?=\n\s*###|\Z)"
    matches = re.findall(pattern, text, re.DOTALL)

    objects = []
    contents = []

    for header, body in matches:
        objects.append(header.strip())
        # Extract content by removing the list number and whitespace
        contents.append(re.sub(r"^\d+\.\s+", "", body.strip()))

    return objects, contents




def process_item(cur_caption, cur_img, keyword):
    """Query the model for one (caption, image) pair.

    Returns:
        A tuple ``(parsed, raw)``: the reply flattened by ``parse_response``
        and the untouched reply text.
    """
    raw_reply = gpt_response(cur_img, cur_caption, keyword)
    return parse_response(raw_reply), raw_reply

import concurrent.futures
def get_data_reponses(caption_list, img_list, keyword):
    """Run ``process_item`` for every caption/image pair on a thread pool.

    Args:
        caption_list: Captions, one per image.
        img_list: Images, same length as ``caption_list``.
        keyword: Relation keyword forwarded to every request.

    Returns:
        A tuple ``(parsed_replies, raw_replies)`` of two lists in input order.
    """
    assert len(caption_list) == len(img_list)
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as pool:
        pending = [
            pool.submit(process_item, caption, image, keyword)
            for caption, image in zip(caption_list, img_list)
        ]
        # Collect in submission order so outputs line up with the inputs.
        outcomes = [job.result() for job in pending]
    parsed_replies = [parsed for parsed, _ in outcomes]
    raw_replies = [raw for _, raw in outcomes]
    return parsed_replies, raw_replies

# def get_data_reponses(caption_list, img_list, keyword):
#     assert len(caption_list) == len(img_list)
#     response_list = []

#     for i in range(len(caption_list)):
#         cur_caption, cur_img = caption_list[i], img_list[i]
#         response = gpt_response(cur_img, cur_caption, keyword)
#         response = parse_response(response)
#         response_list.append(response)
#     return response_list

def cal_lpip_pix(obj_des_list, gt_img_list):
    """Generate an image per description and compare it with the ground truth.

    For every entry, a regional prompt is built from the two object names and
    their two descriptions, an image is generated with the module-level
    ``pipe``, and both the generated and ground-truth images are resized to
    256x256 tensors. Two metrics are then averaged over all pairs.

    Args:
        obj_des_list: Sequence of ``[objects, contents]`` pairs; each of the
            two inner lists must hold at least two entries.
        gt_img_list: Ground-truth images, indexed in parallel with
            ``obj_des_list``.

    Returns:
        A tuple ``(lpips_score, pixcorr)``: the mean LPIPS distance and the
        mean Pearson correlation of the flattened pixels.
    """
    # cal lpip
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    image_tensors = []
    gt_img_tensors = []
    for i, (obj, content) in enumerate(obj_des_list):
        # Base prompt names both objects; the two regional prompts follow.
        objs = obj[0] + " " + obj[1]
        prompt = [objs, content[0], content[1]]
        generated_img = pipe(prompt, negative_prompt,
                  batch_size = 1, #batch size
                  num_inference_steps=40, # sampling step
                  height = 512,
                  width = 512,
                  end_steps = 1, # The number of steps to end the attention double version (specified in a ratio of 0-1. If it is 1, attention double version will be applied in all steps, with 0 being the normal generation)
                  base_ratio=0.4, # Base ratio, the weight of base prompt, if 0, all are regional prompts, if 1, all are base prompts
                  seed = 42)
        image_tensors.append(transform(generated_img[0]))
        gt_img_tensors.append(transform(gt_img_list[i]))
    image_batch = torch.stack(image_tensors)  # (N, C, H, W)
    gt_img_batch = torch.stack(gt_img_tensors)  # (N, C, H, W)

    batch_size = image_batch.shape[0]
    lpips_scores = []

    # ToTensor yields values in [0, 1] while LPIPS expects [-1, 1];
    # normalize=True makes the model rescale its inputs accordingly.
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        for i in range(batch_size):
            score = lpips_model(
                image_batch[i].unsqueeze(0),
                gt_img_batch[i].unsqueeze(0),
                normalize=True,
            )
            lpips_scores.append(score.item())

    lpips_score = sum(lpips_scores) / batch_size

    # pixcorr
    preprocess = transforms.Compose([
        transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
    ])

    # reshape (not view) for both batches: the resize output is not guaranteed
    # to be contiguous.
    all_images_flattened = preprocess(image_batch).reshape(len(image_batch), -1).cpu()
    all_images_gt_flattened = preprocess(gt_img_batch).reshape(len(gt_img_batch), -1).cpu()

    print(all_images_flattened.shape)
    print(all_images_gt_flattened.shape)

    corrsum = 0
    for generated_flat, gt_flat in zip(all_images_flattened, all_images_gt_flattened):
        corrsum += np.corrcoef(generated_flat.numpy(), gt_flat.numpy())[0][1]
    corrmean = corrsum / len(image_batch)

    pixcorr = corrmean
    print("pixcorr: ", pixcorr)

    return lpips_score, pixcorr






def cal_score(fmri, captions, imgs, keyword):
    """Score one keyword against the fMRI data and the ground-truth images.

    Steps: query the model for every caption/image pair, record each parsed
    reply under ``keyword`` in ``response/response<i>.json``, generate images
    from the structured replies to obtain LPIPS and pixel correlation, and
    compute RBF-CKA between ``fmri`` and the flattened text embeddings of the
    parsed replies.

    Args:
        fmri: fMRI feature matrix, one row per sample.
        captions: Captions, one per sample.
        imgs: Ground-truth images, one per sample.
        keyword: Relation keyword being evaluated.

    Returns:
        A tuple ``(cka, lpip_score, pixcorr)`` where ``lpip_score`` is
        ``1 - mean LPIPS``.
    """
    print("Calculating score for keyword:", keyword)
    response_list, non_parsed_response_list = get_data_reponses(captions, imgs, keyword)

    # Merge this keyword's reply into the per-sample JSON record on disk.
    for idx, parsed_reply in enumerate(response_list):
        response_file = f"response/response{idx}.json"
        os.makedirs(os.path.dirname(response_file), exist_ok=True)
        response_dict = {}
        if os.path.exists(response_file):
            with open(response_file, "r") as f:
                response_dict = json.load(f)
        response_dict[keyword] = parsed_reply
        with open(response_file, "w") as f:
            json.dump(response_dict, f, indent=4)

    # [objects, contents] pair for every raw reply
    obj_des_list = [list(parse_response_with_obj(raw_reply)) for raw_reply in non_parsed_response_list]

    lpips_score, pixcorr = cal_lpip_pix(obj_des_list, imgs)

    lpip_score = 1 - lpips_score
    print("LPIPS Score: ", lpip_score)

    # One embedding per parsed reply, flattened to a row vector each
    stacked = torch.stack([get_text_embeddings(reply, device) for reply in response_list])
    flat_embeddings = stacked.reshape(stacked.size(0), -1).cpu().numpy()
    cka = rbf_CKA(fmri, flat_embeddings)

    print("CKA Score: ", cka)
    return cka, lpip_score, pixcorr


def load_data():
    """Load the training arrays from ``train_data/data.npz``.

    Returns:
        A tuple ``(train_fmri, img_list_train, caption_list_train)`` read from
        the archive keys of the same names.

    Raises:
        FileNotFoundError: If the archive does not exist.
        KeyError: If one of the three required keys is missing.
    """
    data_dir = "train_data"
    # NOTE(security): allow_pickle=True can execute arbitrary code while
    # unpickling object arrays; only load archives from a trusted source.
    # The context manager closes the underlying zip file handle; arrays are
    # fully read into memory on access, so they remain valid afterwards.
    with np.load(f"{data_dir}/data.npz", allow_pickle=True) as data:
        train_fmri = data['train_fmri']
        caption_list_train = data['caption_list_train']
        img_list_train = data['img_list_train']
    return train_fmri, img_list_train, caption_list_train



### Search for the keyword
if __name__ == "__main__":
    # Keyword search: score a seed set of relation keywords, then repeatedly
    # sample well-scoring keywords, ask the generator for new candidates,
    # score the unseen ones and checkpoint everything to keywords.json.
    train_fmri, train_img_list, train_caption_list = load_data()
    assert len(train_fmri) == len(train_img_list) == len(train_caption_list)

    imgs = train_img_list
    captions = train_caption_list
    fmris = train_fmri

    ### Test the keyword
    # keyword = "Geometric Relation"
    # score = cal_score(fmris, captions, imgs, keyword)
    # print("Score: ", score)

    ### Define the reference keyword nums and expanding nums
    REF_NUM = 8
    EXP_NUM = 2

    ### Define the starting keywords
    # initial_relation_names = [
    #     "Locational Relations",
    #     "Structural Relations",
    #     "Conceptual Relations",
    #     "Orientation Relations",
    #     "Inclusion Relations",
    #     "Structural Relations"
    # ]

    initial_relation_names = [
        "Spatial Relations",
        "Geometric Relations",
        "Functional Relations",
        "Semantic Relations",
        "Directional Relations",
        "Containment Relations",
        "Support Relations"
    ]

    ## Initialize the keywords and scores
    # Every entry stores only the CKA value (first element of cal_score's
    # result) as a plain float. Storing the full tuple for the seed keywords
    # while generated keywords stored a scalar made the score array below
    # 2-D/ragged, which breaks the sorting and the sampling probabilities.
    keywords_dict = {}
    for keyword in initial_relation_names:
        scores = cal_score(fmris, captions, imgs, keyword)
        keywords_dict[keyword] = float(scores[0])
        print(f"{keyword}: {scores}")

    ## Or read from a json file
    # with open("keywords.json", "r") as f:
    #     keywords_dict = json.load(f)


    # with open("keywords_backup.json", "r") as f:
    #     keywords_dict = json.load(f)

    ### Start the keyword generation process
    generator = keyword_generator()

    for attempt in range(20):
        # sample some relation names
        # Sort keywords by score
        sorted_keywords = sorted(keywords_dict, key=keywords_dict.get, reverse=True)

        # Apply temperature-based sampling
        temperature = 0.7
        logits = np.array([keywords_dict[k] for k in sorted_keywords], dtype=float) / temperature
        # Shifting by the maximum leaves the softmax unchanged but avoids overflow.
        weights = np.exp(logits - logits.max())
        probabilities = weights / np.sum(weights)
        chosen_indices = np.random.choice(
            len(sorted_keywords),
            size=min(REF_NUM, len(sorted_keywords)),
            replace=False,
            p=probabilities
        )
        chosen_relation_names = [sorted_keywords[idx] for idx in chosen_indices]

        print(f"Attempt {attempt+1} to generate keywords from {chosen_relation_names}.")
        new_keywords = generator.generate_key_word(chosen_relation_names, gen_num=EXP_NUM)
        for keyword in new_keywords:
            # Check if the keyword is not already in the dictionary
            if keyword not in keywords_dict:
                # Evaluate the new keyword and add its CKA score to the dictionary
                scores = cal_score(fmris, captions, imgs, keyword)
                # float() keeps the value JSON-serializable even for NumPy scalars
                keywords_dict[keyword] = float(scores[0])

        ### Print the final keywords and their scores
        print(f"Keywords at attempt {attempt+1}:")
        for keyword, score in keywords_dict.items():
            print(f"{keyword}: {score}")

        ### Save the keywords to json file
        with open("keywords.json", "w") as f:
            json.dump(keywords_dict, f, indent=4)

    # target_relation_names = [
    #     "Spatial Relations",
    #     "Geometric Relations",
    #     "Functional Relations",
    #     "Semantic Relations",
    #     "Directional Relations",
    #     "Containment Relations",
    #     "Support Relations"
    #     ]

    # common_relations = set(keywords_dict.keys()) & set(target_relation_names)
    # print(f"Number of common elements: {len(common_relations)}")
    # print(f"Common elements: {common_relations}")